/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Consumer

import scala.collection.mutable.{ArrayBuffer, HashSet}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}

/**
 * Identifies a barrier stage attempt. Within one stage attempt at most one barrier() call can be
 * active at any time, so the pair (stageId, stageAttemptId) is enough to tell which stage attempt
 * a barrier() call comes from.
 */
private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) {
  // Used in the coordinator's log and error messages, e.g. "Stage 3 (Attempt 0)".
  override def toString: String = s"Stage ${stageId} (Attempt ${stageAttemptId})"
}

/**
 * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync
 * request is generated by a BarrierTaskContext method such as `barrier()` (see [[RequestMethod]])
 * and is identified by stageId + stageAttemptId + barrierEpoch. All the blocking global sync
 * requests of one sync are replied to once every request of that sync has been received. If the
 * coordinator is unable to collect enough global sync requests within the configured time, it
 * fails all the pending requests with a timeout exception.
 *
 * @param timeoutInSecs How long, in seconds, a single sync may wait to collect all its requests.
 * @param listenerBus Bus on which stage-completed events are observed to release per-stage state.
 * @param rpcEnv The RpcEnv of this endpoint.
 */
private[spark] class BarrierCoordinator(
    timeoutInSecs: Long,
    listenerBus: LiveListenerBus,
    override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {

  // TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to
  // fetch result, we shall fix the issue.
  // `new Timer(name)` runs a non-daemon thread, so it is cancelled explicitly in onStop().
  private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer")

  // Listen to StageCompleted event, clear corresponding ContextBarrierState.
  private val listener = new SparkListener {
    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
      val stageInfo = stageCompleted.stageInfo
      val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber)
      // Clear ContextBarrierState from a finished stage attempt.
      cleanupBarrierStage(barrierId)
    }
  }

  // Record all active stage attempts that make barrier() call(s), and the corresponding internal
  // state.
  private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState]

  override def onStart(): Unit = {
    super.onStart()
    listenerBus.addToStatusQueue(listener)
  }

  override def onStop(): Unit = {
    try {
      states.forEachValue(1, clearStateConsumer)
      states.clear()
      listenerBus.removeListener(listener)
    } finally {
      try {
        // Terminate the timer thread so it does not outlive this endpoint. Note that this creates
        // the lazy timer if no sync ever used it; the thread is cancelled right away.
        timer.cancel()
      } finally {
        super.onStop()
      }
    }
  }

  /**
   * Provide the current state of a barrier() call. A state is created when a new stage attempt
   * sends out a barrier() call, and recycled on stage completed.
   *
   * All mutable fields are guarded by this object's monitor.
   *
   * @param barrierId Identifier of the barrier stage that makes a barrier() call.
   * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall
   *                 collect `numTasks` requests to succeed.
   */
  private class ContextBarrierState(
      val barrierId: ContextBarrierId,
      val numTasks: Int) {

    // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used
    // to identify each barrier() call. It is increased when a barrier() call succeeds, and set to
    // -1 by clear() so that every later request of this stage attempt is rejected.
    private var barrierEpoch: Int = 0

    // RpcCallContexts of the barrier tasks that have made a blocking runBarrier() call in the
    // current epoch and are still waiting for a reply.
    private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

    // Messages from each barrier task that have made a blocking runBarrier() call, indexed by
    // partition ID. The messages will be replied to all tasks once sync finished.
    private val messages = Array.ofDim[String](numTasks)

    // Request methods collected from tasks inside this barrier sync. All tasks should make sure
    // that they're calling the same method within the same barrier sync phase. In other words,
    // the size of requestMethods should always be 1 for a legitimate barrier sync. Otherwise,
    // the barrier sync would fail if the size of requestMethods becomes greater than 1.
    private val requestMethods = new HashSet[RequestMethod.Value]

    // The timer task that fails the current barrier() call on timeout; null when none is active.
    private var timerTask: TimerTask = null

    // Create the TimerTask that times out the current barrier() call.
    private def initTimerTask(): Unit = {
      timerTask = new TimerTask {
        override def run(): Unit = ContextBarrierState.this.synchronized {
          // TimerTask.cancel() cannot stop a run that has already started and is blocked on this
          // lock, so a sync that just completed (or was cleared) may still reach here. Only act
          // if this task is still the active one; cancelTimerTask() resets the field to null and
          // a later epoch replaces it with a new task.
          if (timerTask eq this) {
            // Timeout current barrier() call, fail all the sync requests.
            requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " +
              s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " +
              s"$timeoutInSecs second(s).")))
            cleanupBarrierStage(barrierId)
          }
        }
      }
    }

    // Cancel the current active TimerTask and release resources.
    private def cancelTimerTask(): Unit = {
      if (timerTask != null) {
        timerTask.cancel()
        timer.purge()
        timerTask = null
      }
    }

    // Process the global sync request. The barrier() call succeeds if enough requests are
    // collected within the configured time, otherwise all the pending requests fail.
    def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
      val taskId = request.taskAttemptId
      val epoch = request.barrierEpoch

      // Require the number of tasks is correctly set from the BarrierTaskContext.
      require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
        s"${request.numTasks} from Task $taskId, previously it was $numTasks.")

      // Check whether the epoch from the barrier tasks matches current barrierEpoch.
      logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.")
      if (epoch != barrierEpoch) {
        // Reject the stale request on its own; it must not touch the current sync's state (in
        // particular `requestMethods`), otherwise it could fail a legitimate sync.
        requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " +
          s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " +
          "properly killed."))
      } else {
        requestMethods.add(request.requestMethod)
        if (requestMethods.size > 1) {
          val error = new SparkException(s"Different barrier sync types found for the " +
            s"sync $barrierId: ${requestMethods.mkString(", ")}. Please use the " +
            s"same barrier sync type within a single sync.")
          (requesters :+ requester).foreach(_.sendFailure(error))
          clear()
        } else {
          // Store the message before registering the requester (or starting the timer), so an
          // out-of-range partitionId fails only this request and leaves no stale requester.
          messages(request.partitionId) = request.message
          // If this is the first sync message received for a barrier() call, start timer to
          // ensure we may timeout for the sync.
          if (requesters.isEmpty) {
            initTimerTask()
            timer.schedule(timerTask, timeoutInSecs * 1000)
          }
          // Add the requester to the RPCCallContexts pending for reply.
          requesters += requester
          logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
            s"$taskId, current progress: ${requesters.size}/$numTasks.")
          if (requesters.size == numTasks) {
            // Reply with a snapshot: `messages` is overwritten by the next epoch, and the RPC
            // layer is not guaranteed to copy the reply before the requesters read it.
            val replyMessages = messages.clone()
            requesters.foreach(_.reply(replyMessages))
            // Finished current barrier() call successfully, clean up ContextBarrierState and
            // increase the barrier epoch.
            logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates " +
              s"from tasks, finished successfully.")
            barrierEpoch += 1
            requesters.clear()
            requestMethods.clear()
            cancelTimerTask()
          }
        }
      }
    }

    // Cleanup the internal state of a barrier stage attempt. Pending requesters are dropped
    // without a reply.
    def clear(): Unit = synchronized {
      // The global sync fails so the stage is expected to retry another attempt, all sync
      // messages come from current stage attempt shall fail.
      barrierEpoch = -1
      requesters.clear()
      cancelTimerTask()
    }
  }

  // Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt.
  private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = {
    val barrierState = states.remove(barrierId)
    if (barrierState != null) {
      barrierState.clear()
    }
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _, _, _, _) =>
      // Get or init the ContextBarrierState correspond to the stage attempt. Use the value
      // returned by computeIfAbsent: a separate states.get() could run after a concurrent
      // cleanup (stage completion or timeout) removed the entry, and return null.
      val barrierId = ContextBarrierId(stageId, stageAttemptId)
      val barrierState = states.computeIfAbsent(barrierId,
        (key: ContextBarrierId) => new ContextBarrierState(key, numTasks))

      barrierState.handleRequest(context, request)
  }

  // Clears one ContextBarrierState; used by onStop() to reset every tracked stage attempt.
  private val clearStateConsumer = new Consumer[ContextBarrierState] {
    override def accept(state: ContextBarrierState): Unit = state.clear()
  }
}

/** Sealed base type of the messages sent to the BarrierCoordinator, such as `RequestToSync`. */
private[spark] sealed trait BarrierCoordinatorMessage
  extends Serializable

/**
 * Global sync request sent from a BarrierTaskContext to the BarrierCoordinator. A request is
 * identified by the triple (stageId, stageAttemptId, barrierEpoch).
 *
 * @param numTasks       number of global sync requests the BarrierCoordinator shall receive
 * @param stageId        ID of the current stage
 * @param stageAttemptId ID of the current stage attempt
 * @param taskAttemptId  unique ID of the current task
 * @param barrierEpoch   ID of a runBarrier() call; a task may make several runBarrier() calls
 * @param partitionId    ID of the partition the current task is assigned to
 * @param message        message sent from the BarrierTaskContext
 * @param requestMethod  the BarrierTaskContext method whose call triggered this request
 */
private[spark] case class RequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptId: Int,
    taskAttemptId: Long,
    barrierEpoch: Int,
    partitionId: Int,
    message: String,
    requestMethod: RequestMethod.Value)
  extends BarrierCoordinatorMessage

/** The BarrierTaskContext method that issued a sync request; ids follow declaration order. */
private[spark] object RequestMethod extends Enumeration {
  val BARRIER: Value = Value
  val ALL_GATHER: Value = Value
}
